package com.jiao.table.listener;

import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.StrUtil;
import com.jiao.comm.entity.QueryContent;
import com.jiao.comm.exception.CloseLoopException;
import com.jiao.comm.exception.ListenerException;
import com.jiao.comm.utils.JDBCUtils;
import com.jiao.datasource.parse.SQLCommand;
import com.jiao.datasource.parse.WriteSQLUnit;
import com.jiao.table.cache.CacheWrite;
import com.jiao.table.config.CacheDataLoad;
import com.jiao.table.entity.TableInfo;
import com.jiao.table.jdbc.TableMapperService;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;

import static com.jiao.datasource.parse.SQLStruct.FROM;
import static com.jiao.datasource.parse.SQLStruct.SELECT;

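/**
 * @Description Base class for SQL execution listeners that react to write statements. Provides helpers to read the
 *              affected-row count, build and run a SELECT for the primary keys matched by a write's WHERE clause,
 *              and load those rows by id; all listener callbacks are no-ops by default.
 * @Author Vincent.jiao
 * @Date 2022/5/13 16:57
 */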
public abstract class WriteSQLListenerAbstract extends SQLExecListenerAbstract {
    /**
     * Returns the affected-row count recorded on a write command.
     *
     * @param sqlCommand the executed command
     * @return {@link WriteSQLUnit#getCount()} when the command's SQL unit is a {@link WriteSQLUnit}, otherwise 0
     */
    public int getWriteCount(SQLCommand sqlCommand) {
        return sqlCommand.getSqlUnit() instanceof WriteSQLUnit
                ? ((WriteSQLUnit) sqlCommand.getSqlUnit()).getCount() : 0;
    }

    /**
     * Builds a SELECT that returns the primary keys of the rows matched by the command's WHERE clause.
     *
     * <p>The statement has the form {@code SELECT <pk> FROM <table> [WHERE ( <where> ) or ( <where> ) ...]}.
     * For a batch command the WHERE clause is repeated once per batch entry and OR-ed together, so that its
     * placeholders line up with the flattened parameters returned by {@link #getWhereParam(SQLCommand)}.
     * Without a WHERE clause the SELECT covers the whole table.
     *
     * @param commandInfo the write command
     * @return the SELECT statement text
     */
    public String getWriteRowIdSql(SQLCommand commandInfo) {
        String tableName = commandInfo.getTableName();
        // NOTE(review): getTableInfo is null for tables not registered in CacheDataLoad (see isExceCloseLoop);
        // the getPrimaryKeyName() call below would then throw a NullPointerException.
        TableInfo tableInfo = CacheDataLoad.getTableInfo(tableName);
        StringBuilder sql = new StringBuilder()
                .append(SELECT).append(' ').append(tableInfo.getPrimaryKeyName())
                .append(' ').append(FROM).append(' ').append(tableName);

        String whereSql = commandInfo.getWhereSql();
        if (StrUtil.isNotEmpty(whereSql)) {
            sql.append(" WHERE ( ").append(whereSql).append(" ) ");

            List<?> batchList = commandInfo.getSqlUnit().getBatchParams();
            if (CollectionUtil.isNotEmpty(batchList)) {
                // The first batch entry is covered by the WHERE above; add one OR group per additional entry.
                for (int i = 1; i < batchList.size(); i++) {
                    sql.append(" or ( ").append(whereSql).append(" ) ");
                }
            }
        }

        return sql.toString();
    }

    /**
     * Runs the SELECT from {@link #getWriteRowIdSql(SQLCommand)} on the command's own connection and
     * collects the primary keys it returns.
     *
     * @param commandInfo the write command
     * @return values of the first result column read via {@link ResultSet#getLong(int)}
     *         (assumes a numeric primary key — TODO confirm for all cached tables)
     * @throws SQLException if the query fails
     */
    public List<Long> getWriteRowId(SQLCommand commandInfo) throws SQLException {
        String sql = getWriteRowIdSql(commandInfo);
        Connection conn = commandInfo.getSqlUnit().getConnection();
        return JDBCUtils.select(conn, sql, new QueryContent<List<Long>>() {
            @Override
            public List<Long> exectu(ResultSet rs) throws SQLException {
                List<Long> ids = new ArrayList<>();
                while (rs.next()) {
                    ids.add(rs.getLong(1));
                }
                return ids;
            }
        }, getWhereParam(commandInfo));
    }

    /**
     * Extracts the parameters bound to the placeholders of the command's WHERE clause.
     *
     * <p>The number of WHERE placeholders is the number of {@code '?'} characters in the WHERE SQL, and those
     * placeholders are assumed to be the trailing parameters of the statement
     * (e.g. {@code UPDATE t SET a = ? WHERE b = ?}).
     * NOTE(review): a literal {@code '?'} inside a quoted string in the WHERE clause is counted as well.
     *
     * @param commandInfo the write command
     * @return a flat array of WHERE parameters. For a single statement it holds one value per placeholder; for a
     *         batch it holds each batch entry's WHERE parameters one after another, matching the OR-ed clauses
     *         built by {@link #getWriteRowIdSql(SQLCommand)}. Empty when there is no WHERE clause or it has no
     *         placeholders.
     */
    public Object[] getWhereParam(SQLCommand commandInfo) {
        String whereSql = commandInfo.getWhereSql();
        // Same emptiness check as getWriteRowIdSql: no WHERE clause means nothing to bind (was an NPE on null).
        if (StrUtil.isEmpty(whereSql)) {
            return new Object[0];
        }

        int count = 0;
        for (int i = 0; i < whereSql.length(); i++) {
            if (whereSql.charAt(i) == '?') {
                count++;
            }
        }
        if (count == 0) {
            return new Object[0];
        }

        List<List> batchParams = commandInfo.getSqlUnit().getBatchParams();
        if (CollectionUtil.isEmpty(batchParams)) {
            return getRetrievalWhereParams(commandInfo.getSqlUnit().getParams(), count);
        }

        // Each batch entry contributes exactly `count` WHERE parameters.
        List<Object> whereParams = new ArrayList<>(batchParams.size() * count);
        for (List<?> batch : batchParams) {
            whereParams.addAll(Arrays.asList(getRetrievalWhereParams(batch, count)));
        }
        return whereParams.toArray();
    }

    /**
     * Takes the last {@code count} entries of a statement's parameter list as its WHERE parameters.
     *
     * @param allParams all parameters bound to the statement, in placeholder order; may be {@code null}
     * @param count     number of WHERE placeholders
     * @return an array of length {@code count}; it is left filled with {@code null} when {@code allParams} is
     *         {@code null} or has fewer than {@code count} entries
     */
    private Object[] getRetrievalWhereParams(List<?> allParams, int count) {
        if (allParams == null || allParams.size() < count) {
            return new Object[count];
        }
        return allParams.subList(allParams.size() - count, allParams.size()).toArray();
    }

    /**
     * Loads full rows by primary key through {@link TableMapperService}, one lookup per id.
     *
     * @param commandInfo the command whose table is queried
     * @param ids         primary-key values; {@code null} or empty yields an empty list
     * @return the loaded rows, in the order of {@code ids}
     */
    public List<Object> getDataById(SQLCommand commandInfo, List<Long> ids) {
        if (CollectionUtil.isEmpty(ids)) {
            return Collections.emptyList();
        }

        String tableName = commandInfo.getTableName();
        List<Object> dataList = new ArrayList<>(ids.size());
        for (Long id : ids) {
            dataList.add(TableMapperService.getInstance()
                    .selectByPrimaryKey(tableName, id, CacheDataLoad.getTableEntityClass(tableName)));
        }

        return dataList;
    }

    /**
     * Decides whether the close-loop handling should run for a command.
     *
     * @param sqlCommand the executed command
     * @return {@code true} when {@link CacheDataLoad} has table info for the command's table
     */
    public boolean isExceCloseLoop(SQLCommand sqlCommand) {
        return CacheDataLoad.getTableInfo(sqlCommand.getTableName()) != null;
    }

    /** No-op by default; subclasses override as needed. */
    @Override
    public void close(Connection conn) { }

    /** No-op by default; subclasses override as needed. */
    @Override
    public void open(Connection conn) { }

    /** No-op by default; subclasses override as needed. */
    @Override
    public void commiteBefore(List<SQLCommand> sqlCommands) throws CloseLoopException { }

    /** No-op by default; subclasses override as needed. */
    @Override
    public void commiteAfter(List<SQLCommand> sqlCommands) throws CloseLoopException { }

    /** No-op by default; subclasses override as needed. */
    @Override
    public void execBefore(SQLCommand commandInfo) { }

    /** No-op by default; subclasses override as needed. */
    @Override
    public void execAfter(SQLCommand commandInfo) { }
}
